{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import math\n",
    "import timeit\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train data shape:  (49000, 32, 32, 3)\n",
      "Train labels shape:  (49000,) int32\n",
      "Validation data shape:  (1000, 32, 32, 3)\n",
      "Validation labels shape:  (1000,)\n",
      "Test data shape:  (10000, 32, 32, 3)\n",
      "Test labels shape:  (10000,)\n"
     ]
    }
   ],
   "source": [
    "def load_cifar10(num_training=49000, num_validation=1000, num_test=10000):\n",
    "    cifar10 = tf.keras.datasets.cifar10.load_data()\n",
    "    (X_train, y_train), (X_test, y_test) = cifar10\n",
    "    X_train = np.asarray(X_train, dtype=np.float32)\n",
    "    y_train = np.asarray(y_train, dtype=np.int32).flatten()\n",
    "    X_test = np.asarray(X_test, dtype=np.float32)\n",
    "    y_test = np.asarray(y_test, dtype=np.int32).flatten()\n",
    "\n",
    "    mask = range(num_training, num_training + num_validation)\n",
    "    X_val = X_train[mask]\n",
    "    y_val = y_train[mask]\n",
    "    mask = range(num_training)\n",
    "    X_train = X_train[mask]\n",
    "    y_train = y_train[mask]\n",
    "    mask = range(num_test)\n",
    "    X_test = X_test[mask]\n",
    "    y_test = y_test[mask]\n",
    "    \n",
    "    mean_pixel = X_train.mean(axis=(0, 1, 2), keepdims=True)\n",
    "    std_pixel = X_train.std(axis=(0, 1, 2), keepdims=True)\n",
    "    X_train = (X_train - mean_pixel) / std_pixel\n",
    "    X_val = (X_val - mean_pixel) / std_pixel\n",
    "    X_test = (X_test - mean_pixel) / std_pixel\n",
    "\n",
    "    return X_train, y_train, X_val, y_val, X_test, y_test\n",
    "\n",
    "NHW = (0, 1, 2)\n",
    "X_train, y_train, X_val, y_val, X_test, y_test = load_cifar10()\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', y_train.shape, y_train.dtype)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset(object):\n",
    "    def __init__(self, x, y, batch_size, shuffle=True):\n",
    "        assert x.shape[0] == y.shape[0], 'Got different numbers of data and labels'\n",
    "        self.x = x \n",
    "        self.y = y\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        \n",
    "    def __iter__(self):\n",
    "        N = self.x.shape[0]\n",
    "        B = self.batch_size\n",
    "        idx = np.arange(N)\n",
    "        if self.shuffle:\n",
    "            np.random.shuffle(idx)\n",
    "        return iter((self.x[i:i+B], self.y[i:i+B]) for i in range(0, N, B))\n",
    "    \n",
    "train_dset = Dataset(X_train, y_train, batch_size=64, shuffle=True)\n",
    "val_dset = Dataset(X_val, y_val, batch_size=64, shuffle=False)\n",
    "test_dset = Dataset(X_test, y_test, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 (64, 32, 32, 3) (64,)\n",
      "[6 9 9 4 1 1 2 7 8 3 4 7 7 2 9 9 9 3 2 6 4 3 6 6 2 6 3 5 4 0 0 9 1 3 4 0 3\n",
      " 7 3 3 5 2 2 7 1 1 1 2 2 0 9 5 7 9 2 2 5 2 4 3 1 1 8 2]\n",
      "1 (64, 32, 32, 3) (64,)\n",
      "[1 1 4 9 7 8 5 9 6 7 3 1 9 0 3 1 3 5 4 5 7 7 4 7 9 4 2 3 8 0 1 6 1 1 4 1 8\n",
      " 3 9 6 6 1 8 5 2 9 9 8 1 7 7 0 0 6 9 1 2 2 9 2 6 6 1 9]\n",
      "2 (64, 32, 32, 3) (64,)\n",
      "[5 0 4 7 6 7 1 8 1 1 2 8 1 3 3 6 2 4 9 9 5 4 3 6 7 4 6 8 5 5 4 3 1 8 4 7 6\n",
      " 0 9 5 1 3 8 2 7 5 3 4 1 5 7 0 4 7 5 5 1 0 9 6 9 0 8 7]\n",
      "3 (64, 32, 32, 3) (64,)\n",
      "[8 8 2 5 2 3 5 0 6 1 9 3 6 9 1 3 9 6 6 7 1 0 9 5 8 5 2 9 0 8 8 0 6 9 1 1 6\n",
      " 3 7 6 6 0 6 6 1 7 1 5 8 3 6 6 8 6 8 4 6 6 1 3 8 3 4 1]\n",
      "4 (64, 32, 32, 3) (64,)\n",
      "[7 1 3 8 5 1 1 4 0 9 3 7 4 9 9 2 4 9 9 1 0 5 9 0 8 2 1 2 0 5 6 3 2 7 8 8 6\n",
      " 0 7 9 4 5 6 4 2 1 1 2 1 5 9 9 0 8 4 1 1 6 3 3 9 0 7 9]\n",
      "5 (64, 32, 32, 3) (64,)\n",
      "[7 7 9 1 5 1 6 6 8 7 1 3 0 3 3 2 4 5 7 5 9 0 3 4 0 4 4 6 0 0 6 6 0 8 1 6 2\n",
      " 9 2 5 9 6 7 4 1 8 7 3 6 9 3 0 4 0 5 1 0 3 4 8 5 4 7 2]\n",
      "6 (64, 32, 32, 3) (64,)\n",
      "[3 9 7 6 7 1 4 7 0 1 7 3 1 8 4 4 2 0 2 2 0 0 9 0 9 6 8 2 7 7 4 0 3 0 8 9 4\n",
      " 2 7 2 5 2 5 1 9 4 8 5 1 7 4 4 0 6 9 0 7 8 8 9 9 3 3 4]\n"
     ]
    }
   ],
   "source": [
    "for t, (x, y) in enumerate(train_dset):\n",
    "    print(t, x.shape, y.shape)\n",
    "    print(y)\n",
    "    if t > 5: break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device:  /cpu:0\n"
     ]
    }
   ],
   "source": [
    "USE_GPU = False\n",
    "print_every = 100\n",
    "\n",
    "if USE_GPU:\n",
    "    device = '/device:GPU:0'\n",
    "else:\n",
    "    device = '/cpu:0'\n",
    "\n",
    "print('Using device: ', device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Barebone TensorFlow\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    N = tf.shape(x)[0]\n",
    "    return tf.reshape(x, (N, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:  <class 'tensorflow.python.framework.ops.Tensor'> Tensor(\"Placeholder:0\", dtype=float32, device=/device:CPU:0)\n",
      "x_flat:  <class 'tensorflow.python.framework.ops.Tensor'> Tensor(\"Reshape:0\", shape=(?, ?), dtype=float32, device=/device:CPU:0)\n",
      "\n",
      "x_np:\n",
      " (2, 3, 4) \n",
      "\n",
      "x_np_flat:\n",
      " (2, 12) \n",
      "\n",
      "x_np:\n",
      " (2, 3, 2) \n",
      "\n",
      "x_flat_np:\n",
      " (2, 6)\n"
     ]
    }
   ],
   "source": [
    "def test_flatten():\n",
    "    tf.reset_default_graph()\n",
    "    with tf.device(device):\n",
    "        x = tf.placeholder(tf.float32)\n",
    "        x_flat = flatten(x)\n",
    "    print('x: ', type(x), x)\n",
    "    print('x_flat: ', type(x_flat), x_flat)\n",
    "    print()\n",
    "    \n",
    "    with tf.Session() as sess:\n",
    "        x_np = np.arange(24).reshape((2, 3, 4))\n",
    "        print('x_np:\\n', x_np.shape, '\\n')\n",
    "        x_np_flat = sess.run(x_flat, feed_dict={x: x_np})\n",
    "        print('x_np_flat:\\n', x_np_flat.shape, '\\n')\n",
    "        \n",
    "        x_np = np.arange(12).reshape((2, 3, 2))\n",
    "        print('x_np:\\n', x_np.shape, '\\n')\n",
    "        x_flat_np = sess.run(x_flat, feed_dict={x: x_np})\n",
    "        print('x_flat_np:\\n', x_flat_np.shape)\n",
    "test_flatten()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def two_layer_fc(x, params):\n",
    "    w1, w2 = params\n",
    "    x = flatten(x)\n",
    "    h = tf.nn.relu(tf.matmul(x, w1))\n",
    "    scores = tf.matmul(h, w2)\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 10)\n"
     ]
    }
   ],
   "source": [
    "def two_layer_fc_test():\n",
    "    tf.reset_default_graph()\n",
    "    hidden_layer_size = 42\n",
    "    \n",
    "    with tf.device(device):\n",
    "        x = tf.placeholder(tf.float64)\n",
    "        w1 = tf.zeros((3*32*32, hidden_layer_size), dtype=tf.float64)\n",
    "        w2 = tf.zeros((hidden_layer_size, 10), dtype=tf.float64)\n",
    "        scores = two_layer_fc(x, (w1, w2))\n",
    "    x_np = np.zeros((64, 32, 32, 3))\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        scores_np = sess.run(scores, feed_dict={x: x_np})\n",
    "        print(scores_np.shape)\n",
    "        \n",
    "two_layer_fc_test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_layer_convnet(x, params):\n",
    "    (w1, b1, w2, b2, w3, b3) = params\n",
    "    out_1 = tf.nn.relu(tf.nn.conv2d(x, w1, strides=[1, 1, 1, 1], padding='SAME') + b1)\n",
    "    out_2 = tf.nn.relu(tf.nn.conv2d(out_1, w2, strides=[1, 1, 1, 1], padding='SAME') + b2)\n",
    "    scores = tf.matmul(flatten(out_2), w3) + b3\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scores_np has shape:  (64, 10)\n"
     ]
    }
   ],
   "source": [
    "def three_layer_convnet_test():\n",
    "    tf.reset_default_graph()\n",
    "    with tf.device(device):\n",
    "        x = tf.placeholder(tf.float32)\n",
    "        w1 = tf.zeros((5, 5, 3, 6))\n",
    "        b1 = tf.zeros((6, ))\n",
    "        w2 = tf.zeros((3, 3, 6, 9))\n",
    "        b2 = tf.zeros((9, ))\n",
    "        w3 = tf.zeros((9*32*32, 10))\n",
    "        b3 = tf.zeros((10, ))\n",
    "        params = (w1, b1, w2, b2, w3, b3)\n",
    "        scores = three_layer_convnet(x, params)\n",
    "    \n",
    "    x_np = np.zeros((64, 32, 32, 3))\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        scores_np = sess.run(scores, feed_dict={x: x_np})\n",
    "        print('scores_np has shape: ', scores_np.shape)\n",
    "        \n",
    "with tf.device('/cpu:0'):\n",
    "    three_layer_convnet_test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_step(scores, y, params, learning_rate):\n",
    "    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=scores)\n",
    "    loss = tf.reduce_mean(losses)\n",
    "    \n",
    "    grad_params = tf.gradients(loss, params)\n",
    "    new_weights = []\n",
    "    for w, grad_w in zip(params, grad_params):\n",
    "        new_w = tf.assign_sub(w, learning_rate * grad_w)\n",
    "        new_weights.append(new_w)\n",
    "        \n",
    "    with tf.control_dependencies(new_weights):\n",
    "        return tf.identity(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model_fn, init_fn, learning_rate):\n",
    "    tf.reset_default_graph()\n",
    "    is_training = tf.placeholder(tf.bool, name='is_training')\n",
    "    with tf.device(device):\n",
    "        x = tf.placeholder(tf.float32, [None, 32, 32, 3])\n",
    "        y = tf.placeholder(tf.int32, [None])\n",
    "        params = init_fn()\n",
    "        scores = model_fn(x, params)\n",
    "        loss = training_step(scores, y, params, learning_rate)\n",
    "        \n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        for t, (x_np, y_np) in enumerate(train_dset):\n",
    "            feed_dict={x: x_np, y: y_np}\n",
    "            loss_np = sess.run(loss, feed_dict=feed_dict)\n",
    "            if t % print_every == 0:\n",
    "                print('Iteration %d, loss = %.4f' % (t, loss_np))\n",
    "                check_accuracy(sess, val_dset, x, scores, is_training)            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_accuracy(sess, dset, x, scores, is_training=None):\n",
    "    num_correct, num_samples = 0, 0\n",
    "    for x_batch, y_batch in dset:\n",
    "        feed_dict = {x: x_batch, is_training: 0}\n",
    "        scores_np = sess.run(scores, feed_dict=feed_dict)\n",
    "        y_pred = scores_np.argmax(axis=1)\n",
    "        num_samples += x_batch.shape[0]\n",
    "        num_correct += (y_pred == y_batch).sum()\n",
    "    acc = float(num_correct) / num_samples\n",
    "    print('Got %d / %d correct (%.2f%%)' % (num_correct, num_samples, 100 * acc))        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kaiming_normal(shape):\n",
    "    if len(shape) == 2:\n",
    "        fan_in, fan_out = shape[0], shape[1]\n",
    "    elif len(shape) == 4:\n",
    "        fan_in, fan_out = np.prod(shape[:3]), shape[3]\n",
    "    return tf.random_normal(shape) * np.sqrt(2.0 / fan_in)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 3.2512\n",
      "Got 117 / 1000 correct (11.70%)\n",
      "Iteration 100, loss = 2.0240\n",
      "Got 382 / 1000 correct (38.20%)\n",
      "Iteration 200, loss = 1.4809\n",
      "Got 392 / 1000 correct (39.20%)\n",
      "Iteration 300, loss = 1.8005\n",
      "Got 376 / 1000 correct (37.60%)\n",
      "Iteration 400, loss = 1.8020\n",
      "Got 420 / 1000 correct (42.00%)\n",
      "Iteration 500, loss = 1.7362\n",
      "Got 432 / 1000 correct (43.20%)\n",
      "Iteration 600, loss = 1.8651\n",
      "Got 414 / 1000 correct (41.40%)\n",
      "Iteration 700, loss = 1.9470\n",
      "Got 437 / 1000 correct (43.70%)\n"
     ]
    }
   ],
   "source": [
    "def two_layer_fc_init():\n",
    "    hidden_layer_size = 4000\n",
    "    w1 = tf.Variable(kaiming_normal((3*32*32, 4000)))\n",
    "    w2 = tf.Variable(kaiming_normal((4000, 10)))\n",
    "    return [w1, w2]\n",
    "\n",
    "learning_rate = 1e-2\n",
    "train(two_layer_fc, two_layer_fc_init, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a three-layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 2.8190\n",
      "Got 105 / 1000 correct (10.50%)\n",
      "Iteration 100, loss = 1.8904\n",
      "Got 362 / 1000 correct (36.20%)\n",
      "Iteration 200, loss = 1.5929\n",
      "Got 394 / 1000 correct (39.40%)\n",
      "Iteration 300, loss = 1.7238\n",
      "Got 377 / 1000 correct (37.70%)\n",
      "Iteration 400, loss = 1.6944\n",
      "Got 416 / 1000 correct (41.60%)\n",
      "Iteration 500, loss = 1.6929\n",
      "Got 421 / 1000 correct (42.10%)\n",
      "Iteration 600, loss = 1.6278\n",
      "Got 459 / 1000 correct (45.90%)\n",
      "Iteration 700, loss = 1.6345\n",
      "Got 450 / 1000 correct (45.00%)\n"
     ]
    }
   ],
   "source": [
    "def three_layer_convnet_init():\n",
    "    w1 = tf.Variable(kaiming_normal((5, 5, 3, 32)))\n",
    "    b1 = tf.Variable(tf.zeros((32, )))\n",
    "    w2 = tf.Variable(kaiming_normal((3, 3, 32, 16)))\n",
    "    b2 = tf.Variable(tf.zeros((16, )))\n",
    "    w3 = tf.Variable(kaiming_normal((16*32*32, 10)))\n",
    "    b3 = tf.Variable(tf.zeros((10, )))\n",
    "    params = (w1, b1, w2, b2, w3, b3)\n",
    "    \n",
    "    return params\n",
    "\n",
    "learning_rate = 3e-3\n",
    "train(three_layer_convnet, three_layer_convnet_init, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keras Model API\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 10)\n"
     ]
    }
   ],
   "source": [
    "class TwoLayerFC(tf.keras.Model):\n",
    "    def __init__(self, hidden_size, num_classes):\n",
    "        super().__init__()\n",
    "        initializer = tf.variance_scaling_initializer(scale=2.0)\n",
    "        self.fc1 = tf.layers.Dense(hidden_size, activation=tf.nn.relu, kernel_initializer=initializer)\n",
    "        self.fc2 = tf.layers.Dense(num_classes, kernel_initializer=initializer)\n",
    "    def call(self, x, training=None):\n",
    "        x = tf.layers.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "def test_TwoLayerFC():\n",
    "    tf.reset_default_graph()\n",
    "    input_size, hidden_size, num_classes = 50, 42, 10\n",
    "    model = TwoLayerFC(hidden_size, num_classes)\n",
    "    with tf.device(device):\n",
    "        x = tf.zeros((64, input_size))\n",
    "        scores = model(x)\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        scores_np = sess.run(scores)\n",
    "        print(scores_np.shape)\n",
    "        \n",
    "test_TwoLayerFC()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 10)\n"
     ]
    }
   ],
   "source": [
    "def two_layer_fc_functional(inputs, hidden_size, num_classes):\n",
    "    initializer = tf.variance_scaling_initializer(scale=2.0)\n",
    "    flatten_inputs = tf.layers.flatten(inputs)\n",
    "    fc1_out = tf.layers.dense(flatten_inputs, hidden_size, activation=tf.nn.relu, kernel_initializer=initializer)\n",
    "    scores = tf.layers.dense(fc1_out, num_classes, kernel_initializer=initializer)\n",
    "    return scores\n",
    "\n",
    "def test_two_layer_fc_functional():\n",
    "    tf.reset_default_graph()\n",
    "    input_size, hidden_size, num_classes = 50, 42, 10\n",
    "    with tf.device(device):\n",
    "        x = tf.zeros((64, input_size))\n",
    "        scores = two_layer_fc_functional(x, hidden_size, num_classes)\n",
    "        \n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        scores_np = sess.run(scores)\n",
    "        print(scores_np.shape)\n",
    "test_two_layer_fc_functional()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ThreeLayerConvNet(tf.keras.Model):\n",
    "    def __init__(self, channel_1, channel_2, num_classes):\n",
    "        super().__init__()\n",
    "        initializer = tf.variance_scaling_initializer(scale=2.0)\n",
    "        self.conv_1 = tf.layers.Conv2D(filters=channel_1, kernel_size=(5, 5), strides=(1, 1), \\\n",
    "                        padding='same', activation=tf.nn.relu, kernel_initializer=initializer)\n",
    "        self.conv_2 = tf.layers.Conv2D(filters=channel_2, kernel_size=(3, 3), strides=(1, 1), \\\n",
    "                        padding='same', activation=tf.nn.relu, kernel_initializer=initializer)\n",
    "        self.fc = tf.layers.Dense(units=num_classes, kernel_initializer=initializer)\n",
    "        \n",
    "    def call(self, x, training=None):\n",
    "        x = self.conv_1(x)\n",
    "        x = self.conv_2(x)\n",
    "        scores = self.fc(tf.layers.flatten(x))\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 10)\n"
     ]
    }
   ],
   "source": [
    "def test_ThreeLayerConvNet():\n",
    "    tf.reset_default_graph()\n",
    "    channel_1, channel_2, num_classes = 12, 8, 10\n",
    "    model = ThreeLayerConvNet(channel_1, channel_2, num_classes)\n",
    "    with tf.device(device):\n",
    "        x = tf.zeros((64, 3, 32, 32))\n",
    "        scores = model(x)\n",
    "        \n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        scores_np = sess.run(scores)\n",
    "        print(scores_np.shape)\n",
    "        \n",
    "test_ThreeLayerConvNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train2(model_init_fn, optimizer_init_fn, num_epochs=1):\n",
    "    tf.reset_default_graph()\n",
    "    with tf.device(device):\n",
    "        x = tf.placeholder(tf.float32, [None, 32, 32, 3])\n",
    "        y = tf.placeholder(tf.int32, [None])\n",
    "        is_training = tf.placeholder(tf.bool, name='is_training')\n",
    "        scores = model_init_fn(x, is_training)\n",
    "        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=scores)\n",
    "        loss = tf.reduce_mean(loss)\n",
    "        \n",
    "        optimizer = optimizer_init_fn()\n",
    "        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "        with tf.control_dependencies(update_ops):\n",
    "            train_op = optimizer.minimize(loss)\n",
    "            \n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        t = 0\n",
    "        for epoch in range(num_epochs):\n",
    "            print('Starting epoch %d' % epoch)\n",
    "            for x_np, y_np in train_dset:\n",
    "                feed_dict = {x: x_np, y: y_np}\n",
    "                loss_np, _ = sess.run([loss, train_op], feed_dict=feed_dict)\n",
    "                if t % print_every == 0:\n",
    "                    print('Iteration %d, loss = %.4f' % (t, loss_np))\n",
    "                    check_accuracy(sess, val_dset, x, scores, is_training=is_training)\n",
    "                    print()\n",
    "                t += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 0\n",
      "Iteration 0, loss = 2.6712\n",
      "Got 136 / 1000 correct (13.60%)\n",
      "\n",
      "Iteration 100, loss = 1.8448\n",
      "Got 381 / 1000 correct (38.10%)\n",
      "\n",
      "Iteration 200, loss = 1.3787\n",
      "Got 407 / 1000 correct (40.70%)\n",
      "\n",
      "Iteration 300, loss = 1.7410\n",
      "Got 399 / 1000 correct (39.90%)\n",
      "\n",
      "Iteration 400, loss = 1.7544\n",
      "Got 431 / 1000 correct (43.10%)\n",
      "\n",
      "Iteration 500, loss = 1.8186\n",
      "Got 454 / 1000 correct (45.40%)\n",
      "\n",
      "Iteration 600, loss = 1.8170\n",
      "Got 427 / 1000 correct (42.70%)\n",
      "\n",
      "Iteration 700, loss = 1.8482\n",
      "Got 447 / 1000 correct (44.70%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hidden_size, num_classes = 4000, 10\n",
    "learning_rate = 1e-2\n",
    "\n",
    "def model_init_fn(inputs, is_training):\n",
    "    return TwoLayerFC(hidden_size, num_classes)(inputs)\n",
    "\n",
    "def optimizer_init_fn():\n",
    "    return tf.train.GradientDescentOptimizer(learning_rate)\n",
    "\n",
    "train2(model_init_fn, optimizer_init_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 0\n",
      "Iteration 0, loss = 2.8083\n",
      "Got 128 / 1000 correct (12.80%)\n",
      "\n",
      "Iteration 100, loss = 1.9207\n",
      "Got 384 / 1000 correct (38.40%)\n",
      "\n",
      "Iteration 200, loss = 1.4632\n",
      "Got 411 / 1000 correct (41.10%)\n",
      "\n",
      "Iteration 300, loss = 1.7852\n",
      "Got 369 / 1000 correct (36.90%)\n",
      "\n",
      "Iteration 400, loss = 1.7783\n",
      "Got 425 / 1000 correct (42.50%)\n",
      "\n",
      "Iteration 500, loss = 1.7788\n",
      "Got 434 / 1000 correct (43.40%)\n",
      "\n",
      "Iteration 600, loss = 1.7856\n",
      "Got 433 / 1000 correct (43.30%)\n",
      "\n",
      "Iteration 700, loss = 1.8244\n",
      "Got 434 / 1000 correct (43.40%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hidden_size, num_classes = 4000, 10\n",
    "learning_rate = 1e-2\n",
    "\n",
    "def model_init_fn(inputs, is_training):\n",
    "    return two_layer_fc_functional(inputs, hidden_size, num_classes)\n",
    "\n",
    "def optimizer_init_fn():\n",
    "    return tf.train.GradientDescentOptimizer(learning_rate)\n",
    "\n",
    "train2(model_init_fn, optimizer_init_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 0\n",
      "Iteration 0, loss = 2.6920\n",
      "Got 90 / 1000 correct (9.00%)\n",
      "\n",
      "Iteration 100, loss = 1.6495\n",
      "Got 414 / 1000 correct (41.40%)\n",
      "\n",
      "Iteration 200, loss = 1.4028\n",
      "Got 486 / 1000 correct (48.60%)\n",
      "\n",
      "Iteration 300, loss = 1.3792\n",
      "Got 506 / 1000 correct (50.60%)\n",
      "\n",
      "Iteration 400, loss = 1.2277\n",
      "Got 504 / 1000 correct (50.40%)\n",
      "\n",
      "Iteration 500, loss = 1.5428\n",
      "Got 540 / 1000 correct (54.00%)\n",
      "\n",
      "Iteration 600, loss = 1.4736\n",
      "Got 567 / 1000 correct (56.70%)\n",
      "\n",
      "Iteration 700, loss = 1.3710\n",
      "Got 577 / 1000 correct (57.70%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learning_rate = 3e-3\n",
    "channel_1, channel_2, num_classes = 32, 16, 10\n",
    "\n",
    "def model_init_fn(inputs, is_training):\n",
    "    return ThreeLayerConvNet(channel_1, channel_2, num_classes)(inputs)\n",
    "\n",
    "def optimizer_init_fn():\n",
    "    return tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9, use_nesterov=True)\n",
    "\n",
    "train2(model_init_fn, optimizer_init_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keras Sequential API\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 0\n",
      "Iteration 0, loss = 2.7439\n",
      "Got 116 / 1000 correct (11.60%)\n",
      "\n",
      "Iteration 100, loss = 1.8813\n",
      "Got 375 / 1000 correct (37.50%)\n",
      "\n",
      "Iteration 200, loss = 1.3769\n",
      "Got 403 / 1000 correct (40.30%)\n",
      "\n",
      "Iteration 300, loss = 1.7591\n",
      "Got 403 / 1000 correct (40.30%)\n",
      "\n",
      "Iteration 400, loss = 1.7104\n",
      "Got 413 / 1000 correct (41.30%)\n",
      "\n",
      "Iteration 500, loss = 1.7225\n",
      "Got 450 / 1000 correct (45.00%)\n",
      "\n",
      "Iteration 600, loss = 1.7909\n",
      "Got 440 / 1000 correct (44.00%)\n",
      "\n",
      "Iteration 700, loss = 1.8043\n",
      "Got 452 / 1000 correct (45.20%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learning_rate = 1e-2\n",
    "\n",
    "def model_init_fn(inputs, is_training):\n",
    "    input_shape = (32, 32, 3)\n",
    "    hidden_layer_size, num_classes = 4000, 10\n",
    "    initializer = tf.variance_scaling_initializer(scale=2.0)\n",
    "    \n",
    "    layers = [\n",
    "        tf.layers.Flatten(input_shape=input_shape),\n",
    "        tf.layers.Dense(hidden_layer_size, activation=tf.nn.relu, kernel_initializer=initializer),\n",
    "        tf.layers.Dense(num_classes, kernel_initializer=initializer)\n",
    "    ]\n",
    "    \n",
    "    return tf.keras.Sequential(layers)(inputs)\n",
    "\n",
    "def optimizer_init_fn():\n",
    "    return tf.train.GradientDescentOptimizer(learning_rate)\n",
    "\n",
    "train2(model_init_fn, optimizer_init_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 0\n",
      "Iteration 0, loss = 2.4473\n",
      "Got 125 / 1000 correct (12.50%)\n",
      "\n",
      "Iteration 100, loss = 1.8573\n",
      "Got 372 / 1000 correct (37.20%)\n",
      "\n",
      "Iteration 200, loss = 1.4919\n",
      "Got 432 / 1000 correct (43.20%)\n",
      "\n",
      "Iteration 300, loss = 1.6237\n",
      "Got 455 / 1000 correct (45.50%)\n",
      "\n",
      "Iteration 400, loss = 1.5683\n",
      "Got 476 / 1000 correct (47.60%)\n",
      "\n",
      "Iteration 500, loss = 1.6352\n",
      "Got 480 / 1000 correct (48.00%)\n",
      "\n",
      "Iteration 600, loss = 1.5678\n",
      "Got 501 / 1000 correct (50.10%)\n",
      "\n",
      "Iteration 700, loss = 1.5124\n",
      "Got 512 / 1000 correct (51.20%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def model_init_fn(inputs, is_training):\n",
    "    input_shape = (32, 32, 3)\n",
    "    initializer = tf.variance_scaling_initializer(scale=2.0)\n",
    "    \n",
    "    layers = [\n",
    "        tf.layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(5, 5), strides=(1, 1),\\\n",
    "                        padding='same', activation=tf.nn.relu, kernel_initializer=initializer),\n",
    "        tf.layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1),\\\n",
    "                        padding='same', activation=tf.nn.relu, kernel_initializer=initializer),\n",
    "        tf.layers.Flatten(),\n",
    "        tf.layers.Dense(units=10, kernel_initializer=initializer)\n",
    "    ]\n",
    "    \n",
    "    return tf.keras.Sequential(layers)(inputs)\n",
    "\n",
    "learning_rate = 5e-4\n",
    "def optimizer_init_fn():\n",
    "    return tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9, use_nesterov=True)\n",
    "\n",
    "train2(model_init_fn, optimizer_init_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
